import numpy as np
from math import *
from env import dynamics
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
# Fix the global NumPy and PyTorch RNG seeds so every run is reproducible.
np.random.seed(2)
torch.manual_seed(2)

def choose_action(policy_distribution, num_action=None):
  """Sample an action index from a discrete policy via inverse-CDF sampling.

  Draws ``u ~ U[0, 1)`` and returns the first index whose cumulative
  probability reaches ``u``.

  Args:
    policy_distribution: sequence of per-action probabilities (expected to
      sum to ~1).
    num_action: number of leading entries to sample from. Defaults to
      ``len(policy_distribution)`` instead of silently reading a
      module-level global.

  Returns:
    int: the sampled action index. If rounding leaves the cumulative sum
    just below the draw, the last considered action is returned rather than
    falling off the loop and returning ``None``.

  Raises:
    ValueError: if there are no actions to sample from.
  """
  if num_action is None:
    num_action = len(policy_distribution)
  if num_action <= 0:
    raise ValueError("policy distribution has no actions")
  choice = np.random.uniform()
  cumulative = 0.0
  for a in range(num_action):
    cumulative += policy_distribution[a]
    if cumulative >= choice:
      return a
  # Floating-point error can leave the CDF slightly below 1.0; treat the
  # remainder as belonging to the final action.
  return num_action - 1

def trial(initial_state, policy1, policy2, num_action):
  """Roll out both agents for one fixed-length (35-step) episode.

  Args:
    initial_state: 4x1 column ``[x1, y1, x2, y2]`` with both agents' cells.
    policy1, policy2: arrays indexed ``[x][y][action]`` giving each agent's
      action probabilities.
    num_action: unused; kept for interface compatibility.

  Returns:
    list of 35 rows ``[x1, y1, x2, y2, action1, action2]``; each row records
    the joint state *before* that step's actions were applied.
  """
  trajectory = []
  state = initial_state
  for _ in range(35):
    policy1_distribution = policy1[state.item(0)][state.item(1)][:]
    action1 = choose_action(policy1_distribution)
    # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
    next_state1 = dynamics(state[0:2], np.asmatrix([action1]).T)
    policy2_distribution = policy2[state.item(2)][state.item(3)][:]
    action2 = choose_action(policy2_distribution)
    next_state2 = dynamics(state[2:4], np.asmatrix([action2]).T)

    trajectory.append([state.item(0), state.item(1), state.item(2), state.item(3), action1, action2])
    # Re-stack both agents' next positions into the 4x1 joint state.
    state = np.copy(np.vstack((next_state1, next_state2)))
  return trajectory

def soft_policy(Q_matrix, V_matrix, num_action):
  """Return the Boltzmann policy ``pi(a | x, y) = exp(Q[x][y][a] - V[x][y])``.

  Args:
    Q_matrix: soft Q-values indexed ``[x][y][a]`` over the 10x13 grid.
    V_matrix: soft values indexed ``[x][y]`` (log-sum-exp of Q over actions).
    num_action: number of actions per cell.

  Returns:
    Object-dtype array of shape ``(10, 13, num_action)`` with probabilities.
  """
  # np.object was removed in NumPy 1.24; the builtin ``object`` is equivalent.
  distribution = np.zeros((10, 13, num_action)).astype(object)
  for x in range(10):
    for y in range(13):
      for a in range(num_action):
        # exp(Q - V) equals exp(Q) / exp(V) mathematically, but does not
        # raise OverflowError when Q and V are both large (> ~709).
        distribution[x][y][a] = exp(Q_matrix[x][y][a] - V_matrix[x][y])
  return distribution

def soft_Q_matrix_function(gamma, reward_matrix, cost_matrix, V_matrix, num_action):
  """One soft Bellman backup: ``Q[x][y][a] = r(x,y) - c(x,y) + gamma * V(next)``.

  The successor cell for each (cell, action) pair comes from
  ``env.dynamics`` (assumed deterministic -- TODO confirm).

  Args:
    gamma: discount factor.
    reward_matrix, cost_matrix: 10x13 per-cell reward and cost.
    V_matrix: current 10x13 soft values.
    num_action: number of actions per cell.

  Returns:
    Object-dtype array of shape ``(10, 13, num_action)``.
  """
  # np.object was removed in NumPy 1.24; the builtin ``object`` is equivalent.
  Q_matrix = np.zeros((10, 13, num_action)).astype(object)
  for x in range(10):
    for y in range(13):
      # Reward and cost depend only on the current cell, not on the action.
      immediate = reward_matrix[x, y] - cost_matrix[x, y]
      for a in range(num_action):
        # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
        next_state = dynamics(np.asmatrix([x, y]).T, np.asmatrix([a]).T)
        value = V_matrix[next_state.item(0)][next_state.item(1)]
        Q_matrix[x][y][a] = immediate + gamma * value
  return Q_matrix

  
def soft_V_matrix_funciton(Q_matrix, num_action):
  """Return soft values ``V[x][y] = log(sum_a exp(Q[x][y][a]))`` on the 10x13 grid.

  Args:
    Q_matrix: soft Q-values indexed ``[x][y][a]``.
    num_action: number of actions per cell.

  Returns:
    Object-dtype array of shape ``(10, 13)``.
  """
  # np.object was removed in NumPy 1.24; the builtin ``object`` is equivalent.
  V_matrix = np.zeros((10, 13)).astype(object)
  for x in range(10):
    for y in range(13):
      q_values = [Q_matrix[x][y][a] for a in range(num_action)]
      # Max-shifted log-sum-exp: the naive form raises OverflowError once any
      # Q exceeds ~709, and ValueError (log(0)) once all Q fall below ~-745.
      q_max = max(q_values)
      V_matrix[x][y] = q_max + log(sum(exp(q - q_max) for q in q_values))
  return V_matrix

def calculate_policy(omega, gamma, num_action):
  """Compute both agents' soft-optimal policies under the constraint map ``omega``.

  Agent 1 earns reward 5 at cell (9, 12) and agent 2 earns reward 5 at
  cell (9, 0). Both pay ``100 * omega[x, y]`` in every cell. Soft value
  iteration starts from V = 0 and applies 51 Q/V backups per agent.

  Args:
    omega: 10x13 array; nonzero entries mark cells believed to be constraints.
    gamma: discount factor.
    num_action: number of actions per cell.

  Returns:
    (policy1, policy2): object arrays of shape ``(10, 13, num_action)``.
  """
  # np.object was removed in NumPy 1.24; the builtin ``object`` is equivalent.
  reward1_matrix = np.zeros((10, 13)).astype(object)
  reward2_matrix = np.zeros((10, 13)).astype(object)
  reward1_matrix[9, 12] = 5
  reward2_matrix[9, 0] = 5

  # Heavy penalty on cells flagged in omega.
  cost_matrix = 100 * omega

  soft_V1_matrix = np.zeros((10, 13)).astype(object)
  soft_V2_matrix = np.zeros((10, 13)).astype(object)
  # One backup from V = 0 plus 50 more (51 total). The two agents are updated
  # in lockstep, which preserves the original order of dynamics() calls.
  for _ in range(51):
    soft_Q1_matrix = soft_Q_matrix_function(gamma, reward1_matrix, cost_matrix, soft_V1_matrix, num_action)
    soft_V1_matrix = soft_V_matrix_funciton(soft_Q1_matrix, num_action)
    soft_Q2_matrix = soft_Q_matrix_function(gamma, reward2_matrix, cost_matrix, soft_V2_matrix, num_action)
    soft_V2_matrix = soft_V_matrix_funciton(soft_Q2_matrix, num_action)

  policy1 = soft_policy(soft_Q1_matrix, soft_V1_matrix, num_action)
  policy2 = soft_policy(soft_Q2_matrix, soft_V2_matrix, num_action)
  return policy1, policy2


def constraint_map(number_trials, trajectories):
  """Return the average number of visits per cell (both agents) per trial.

  Args:
    number_trials: number of 35-step trials stored in ``trajectories``.
    trajectories: 2-D array with at least ``35 * number_trials`` rows of
      ``[x1, y1, x2, y2, a1, a2]``.

  Returns:
    Object-dtype 10x13 array of visit counts divided by ``number_trials``.

  Raises:
    IndexError: if ``trajectories`` has fewer than ``35 * number_trials`` rows.
  """
  # np.object was removed in NumPy 1.24. The local is also renamed so it no
  # longer shadows this function's own name.
  visits = np.zeros((10, 13)).astype(object)
  for r in range(35 * number_trials):
    row = trajectories[r]
    visits[int(row[0]), int(row[1])] += 1.0
    visits[int(row[2]), int(row[3])] += 1.0
  return visits / number_trials

def false_positive_negative_rate(omega):
  """Score a learned constraint map against the ground-truth obstacles.

  Ground truth comes from ``obstacle_collision``, so the obstacle geometry is
  defined in exactly one place. The denominators are derived from it rather
  than hard-coded (77 obstacle cells and 53 free cells on this 10x13 grid).

  Args:
    omega: 10x13 array; entries > 0 are cells predicted to be constraints.

  Returns:
    (false_positive_rate, false_negative_rate):
      * predicted-but-free cells / number of free cells
      * missed obstacle cells / number of obstacle cells
  """
  num_obstacles = 0.0
  true_positive = 0.0
  false_positive = 0.0
  for x in range(10):
    for y in range(13):
      is_obstacle = obstacle_collision(x, y)
      if is_obstacle:
        num_obstacles += 1.0
      if omega[x, y] > 0.0:
        if is_obstacle:
          true_positive += 1.0
        else:
          false_positive += 1.0
  num_free = 10 * 13 - num_obstacles
  return false_positive / num_free, (num_obstacles - true_positive) / num_obstacles

def obstacle_collision(x,y):
  """Return True iff grid cell (x, y) is a ground-truth obstacle."""
  # Large rectangular obstacle block.
  if 3 <= x <= 9 and 2 <= y <= 10:
    return True
  if x == 0:
    # Two wall segments along the x == 0 edge, with a gap at y == 6.
    return bool(1 <= y <= 5 or 7 <= y <= 11)
  # Isolated single-cell obstacles.
  return (x, y) in ((8, 0), (2, 1), (3, 11), (7, 12))

def constraint_violation_rate(number_trials,trajectories):
  """Return the mean per-trial fraction of agents (out of 2) that hit an obstacle.

  Each trial occupies 35 consecutive rows of ``trajectories``. Agent 1 uses
  columns 0-1 and agent 2 uses columns 2-3. An agent counts as violating if
  any of its recorded cells is an obstacle.
  """
  per_trial = []
  for trial_index in range(number_trials):
    rows = range(35 * trial_index, 35 * (trial_index + 1))
    agent1_hit = any(obstacle_collision(trajectories[r, 0], trajectories[r, 1]) for r in rows)
    agent2_hit = any(obstacle_collision(trajectories[r, 2], trajectories[r, 3]) for r in rows)
    per_trial.append((float(agent1_hit) + float(agent2_hit)) / 2)
  return sum(per_trial) / len(per_trial)

def success_rate(number_trials,trajectories):
  """Return the mean per-trial fraction of agents (out of 2) that reach their goal.

  Each trial occupies 35 consecutive rows of ``trajectories``. Agent 1
  (columns 0-1) must reach (9, 12) and agent 2 (columns 2-3) must reach
  (9, 0), each before entering any obstacle cell.
  """
  def agent_score(first_row, x_col, y_col, goal_x, goal_y):
    # 1.0 if the goal is visited before any collision in this 35-row window.
    for row in range(first_row, first_row + 35):
      x = trajectories[row, x_col]
      y = trajectories[row, y_col]
      if obstacle_collision(x, y):
        return 0.0
      if x == goal_x and y == goal_y:
        return 1.0
    return 0.0

  per_trial = [
    (agent_score(35 * t, 0, 1, 9, 12) + agent_score(35 * t, 2, 3, 9, 0)) / 2
    for t in range(number_trials)
  ]
  return sum(per_trial) / len(per_trial)

def likelihood_function(policy1,policy2,num_trials,gamma,num_iteration):
  """Discounted log-likelihood of the first expert demonstrations under two policies.

  Loads ``expert_trajectory_file.txt`` (``35 * num_trials`` rows of
  ``[x1, y1, x2, y2, a1, a2]``) and scores the first ``2 * num_iteration``
  35-step trajectories. Step ``j`` is weighted by ``gamma ** j``. Actions
  with probability exactly 0 are skipped rather than contributing ``-inf``.

  NOTE(review): experiment() passes its 0-based iteration index as
  ``num_iteration``, so on the first online iteration no demonstrations are
  scored and every candidate gets 0.0. Confirm this is intended.
  """
  demos = np.loadtxt("expert_trajectory_file.txt", dtype=float).reshape(35 * num_trials, 6)
  total = 0.0
  for k in range(2 * num_iteration):
    window = demos[35 * k:35 * (k + 1), :]
    for step in range(35):
      x1, y1, x2, y2, a1, a2 = (int(v) for v in window[step])
      weight = gamma ** step
      p1 = policy1[x1, y1, a1]
      if p1 != 0:
        total = total + weight * log(p1)
      p2 = policy2[x2, y2, a2]
      if p2 != 0:
        total = total + weight * log(p2)
  return total

# Global experiment configuration.
num_action = 4  # actions available to each agent per step
gamma = 1.0  # discount factor
num_trials = 50  # learner rollouts per evaluation; expected size of the expert file (in trajectories)
# Joint start state [x1, y1, x2, y2] as a 4x1 column. np.asmatrix replaces
# np.mat, which was removed in NumPy 2.0.
initial_state = np.asmatrix([9, 0, 9, 12]).T

def experiment():
  """Run one greedy online constraint-inference experiment (20 iterations).

  Each iteration:
  1. Tentatively marks every not-yet-selected cell as a constraint,
     recomputes both soft policies, and scores the expert likelihood.
  2. Commits the best-scoring cell to ``omega``.
  3. Rolls out the resulting policies ``num_trials`` times and records
     the evaluation metrics.

  Returns:
    Four lists, one entry per iteration: false-positive rate,
    false-negative rate, constraint-violation rate, and success rate.
  """
  false_positive_list = []
  false_negative_list = []
  constraint_violation_list = []
  success_list = []
  # np.object was removed in NumPy 1.24; the builtin ``object`` is equivalent.
  omega = np.zeros((10, 13)).astype(object)
  for i in range(20):
    print('online iteration', i+1)
    print('omega', omega)
    # Already-selected cells keep this sentinel so argmax never re-picks them.
    likelihood_matrix = (-1000000000.0 * np.ones((10, 13))).astype(object)
    for x in range(10):
      for y in range(13):
        if omega[x, y] == 1.0:
          continue
        # Tentatively add (x, y), score the expert data, then undo.
        omega[x, y] = 1.0
        policy1, policy2 = calculate_policy(omega, gamma, num_action)
        likelihood_matrix[x, y] = likelihood_function(policy1, policy2, num_trials, gamma, i)
        omega[x, y] = 0.0
    # Row-major flat index -> (x, y) on the 13-column grid.
    x_index, y_index = divmod(int(np.argmax(likelihood_matrix)), 13)

    omega[x_index, y_index] = 1.0
    policy1, policy2 = calculate_policy(omega, gamma, num_action)

    # np.savetxt on a 1-D row writes one value per line; the reload below
    # reshapes the flat stream back into (35 * num_trials, 6).
    with open("learner_trajectory_file.txt", "w") as trajectory_file:
      for _ in range(num_trials):
        trajectory = np.copy(trial(initial_state, policy1, policy2, num_action))
        for entry in trajectory:
          np.savetxt(trajectory_file, entry)
    b = np.loadtxt("learner_trajectory_file.txt", dtype=float)
    learner_trajectories = b.reshape(35 * num_trials, 6)
    false_positive_rate, false_negative_rate = false_positive_negative_rate(omega)
    constraint_violation_mean = constraint_violation_rate(num_trials, learner_trajectories)
    success_mean = success_rate(num_trials, learner_trajectories)

    false_positive_list.append(false_positive_rate)
    false_negative_list.append(false_negative_rate)
    constraint_violation_list.append(constraint_violation_mean)
    success_list.append(success_mean)

    # Last rollout of this iteration, for eyeballing.
    print(trajectory)

    print('false positive rate', false_positive_rate)
    print('false negative rate', false_negative_rate)
    print('constraint violation', constraint_violation_mean)
    print('success_rate', success_mean)

  return false_positive_list, false_negative_list, constraint_violation_list, success_list

# Repeat the full greedy experiment and keep each run's per-iteration metrics.
num_experiment = 2
all_false_positive_list = []
all_false_negative_list = []
all_constraint_violation_list = []
all_success_list = []

start_time = time.time()
for _run in range(num_experiment):
  run_metrics = experiment()
  collectors = (all_false_positive_list, all_false_negative_list,
                all_constraint_violation_list, all_success_list)
  for collector, series in zip(collectors, run_metrics):
    collector.append(series)
end_time = time.time()
print('time cost for one experiment', (end_time - start_time) / num_experiment)

# Per-iteration mean and population standard deviation across experiments.
# Each list is seeded with an iteration-0 entry for the empty constraint set:
# no false positives and every obstacle missed (FN rate 1.0). The violation
# of 1.0 and success of 0.0 are presumably the unconstrained baseline. Each
# entry is a one-element list so np.savetxt writes one value per line.
false_positive_mean_list = [[0.0]]
false_positive_sd_list = [[0.0]]
false_negative_mean_list = [[1.0]]
false_negative_sd_list = [[0.0]]
constraint_violation_mean_list = [[1.0]]
constraint_violation_sd_list = [[0.0]]
success_mean_list = [[0.0]]
success_sd_list = [[0.0]]

# Derive the iteration count from the collected runs instead of duplicating
# experiment()'s hard-coded loop length.
num_iterations = len(all_false_positive_list[0])
metric_series = (
  (all_false_positive_list, false_positive_mean_list, false_positive_sd_list),
  (all_false_negative_list, false_negative_mean_list, false_negative_sd_list),
  (all_constraint_violation_list, constraint_violation_mean_list, constraint_violation_sd_list),
  (all_success_list, success_mean_list, success_sd_list),
)
for i in range(num_iterations):
  for runs, mean_list, sd_list in metric_series:
    samples = [run[i] for run in runs]
    mean_list.append([sum(samples) / len(samples)])
    sd_list.append([sqrt(np.var(samples))])

# Write each metric series to its own text file, one value per line, with the
# iteration-0 baseline first. ``with`` guarantees each file is closed even if
# a write fails.
output_files = (
  ("greedy_false_positive_mean_file.txt", false_positive_mean_list),
  ("greedy_false_negative_mean_file.txt", false_negative_mean_list),
  ("greedy_constraint_violation_mean_file.txt", constraint_violation_mean_list),
  ("greedy_success_mean_file.txt", success_mean_list),
  ("greedy_false_positive_sd_file.txt", false_positive_sd_list),
  ("greedy_false_negative_sd_file.txt", false_negative_sd_list),
  ("greedy_constraint_violation_sd_file.txt", constraint_violation_sd_list),
  ("greedy_success_sd_file.txt", success_sd_list),
)
for filename, rows in output_files:
  with open(filename, "w") as out_file:
    for entry in rows:
      np.savetxt(out_file, entry)

























